{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca1a3483",
   "metadata": {},
   "source": [
    "## Description：\n",
    "这里跑通SDM的baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae4a43d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:\n",
      "DeepCTR version 0.9.0 detected. Your version is 0.8.2.\n",
      "Use `pip install -U deepctr` to upgrade.Changelog: https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import random\n",
    "from datetime import datetime\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# 从utils里面导入函数\n",
    "from utils import get_data_set, gen_model_input, train_sdm_model\n",
    "from utils import get_embeddings, get_sdm_recall_res\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8eedd2fb",
   "metadata": {},
   "source": [
    "## 导入数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cffd1cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../data_process'\n",
    "data = pd.read_csv(os.path.join(data_path, 'train_data.csv'), index_col=0, parse_dates=['expo_time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc803379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择出需要用到的列\n",
    "use_cols = ['user_id', 'article_id', 'expo_time', 'net_status', 'exop_position', 'duration', 'city', 'age', 'gender', 'click', 'cat_1', 'cat_2']\n",
    "data_new = data[use_cols]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab05d27f",
   "metadata": {},
   "source": [
    "## 划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aa50918c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 按照用户分组，然后把最后一个item拿出来\n",
    "click_df = data_new[data_new['click']==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "91ed6bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hist_and_last_click(all_click):\n",
    "    all_click = all_click.sort_values(by=['user_id', 'expo_time'])\n",
    "    click_last_df = all_click.groupby('user_id').tail(1)\n",
    "    \n",
    "    # 如果用户只有一个点击，hist为空了，会导致训练的时候这个用户不可见，此时默认泄露一下\n",
    "    def hist_func(user_df):\n",
    "        if len(user_df) == 1:\n",
    "            return user_df\n",
    "        else:\n",
    "            return user_df[:-1]\n",
    "\n",
    "    click_hist_df = all_click.groupby('user_id').apply(hist_func).reset_index(drop=True)\n",
    "\n",
    "    return click_hist_df, click_last_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cff45bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_click_hist_df, user_click_last_df = get_hist_and_last_click(click_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a316035f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>article_id</th>\n",
       "      <th>expo_time</th>\n",
       "      <th>net_status</th>\n",
       "      <th>exop_position</th>\n",
       "      <th>duration</th>\n",
       "      <th>city</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>click</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>cat_2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>17340</td>\n",
       "      <td>464481478</td>\n",
       "      <td>2021-06-30 20:34:47</td>\n",
       "      <td>2</td>\n",
       "      <td>21</td>\n",
       "      <td>27</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_30_39</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>汽车</td>\n",
       "      <td>汽车/买车</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>17340</td>\n",
       "      <td>465148736</td>\n",
       "      <td>2021-07-02 19:35:03</td>\n",
       "      <td>5</td>\n",
       "      <td>23</td>\n",
       "      <td>49</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_30_39</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>旅游</td>\n",
       "      <td>旅游/旅游资讯</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>17340</td>\n",
       "      <td>464707540</td>\n",
       "      <td>2021-07-02 19:47:06</td>\n",
       "      <td>5</td>\n",
       "      <td>25</td>\n",
       "      <td>174</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_30_39</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>健康</td>\n",
       "      <td>健康/疾病防护治疗及西医用药</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>17340</td>\n",
       "      <td>464993414</td>\n",
       "      <td>2021-07-02 19:47:06</td>\n",
       "      <td>5</td>\n",
       "      <td>27</td>\n",
       "      <td>11</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_30_39</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>搞笑</td>\n",
       "      <td>搞笑/段子</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17340</td>\n",
       "      <td>465115022</td>\n",
       "      <td>2021-07-02 20:01:34</td>\n",
       "      <td>5</td>\n",
       "      <td>41</td>\n",
       "      <td>14</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_30_39</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>房产</td>\n",
       "      <td>房产/买房卖房</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  article_id           expo_time  net_status  exop_position  \\\n",
       "0    17340   464481478 2021-06-30 20:34:47           2             21   \n",
       "1    17340   465148736 2021-07-02 19:35:03           5             23   \n",
       "2    17340   464707540 2021-07-02 19:47:06           5             25   \n",
       "3    17340   464993414 2021-07-02 19:47:06           5             27   \n",
       "4    17340   465115022 2021-07-02 20:01:34           5             41   \n",
       "\n",
       "   duration city      age gender  click cat_1           cat_2  \n",
       "0        27   上海  A_30_39   male      1    汽车           汽车/买车  \n",
       "1        49   上海  A_30_39   male      1    旅游         旅游/旅游资讯  \n",
       "2       174   上海  A_30_39   male      1    健康  健康/疾病防护治疗及西医用药  \n",
       "3        11   上海  A_30_39   male      1    搞笑           搞笑/段子  \n",
       "4        14   上海  A_30_39   male      1    房产         房产/买房卖房  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "user_click_hist_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e5fc454",
   "metadata": {},
   "source": [
    "## SDM召回"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "97aa16a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sdm_recall(data, topk=200, embedding_dim=32, SEQ_LEN_short=5, SEQ_LEN_prefer=50,\n",
    "                      batch_size=64, epochs=1, verbose=1, validation_split=0.0):\n",
    "    \"\"\"通过SDM模型，计算用户向量和文章向量\n",
    "    param: data: 用户日志数据\n",
    "    topk: 对于每个用户，召回多少篇文章\n",
    "    \"\"\"\n",
    "    user_id_raw = data[['user_id']].drop_duplicates('user_id')\n",
    "    doc_id_raw = data[['article_id']].drop_duplicates('article_id')\n",
    "    \n",
    "    # 类别数据编码   \n",
    "    base_features = ['user_id', 'article_id', 'city', 'age', 'gender', 'cat_1', 'cat_2']\n",
    "    feature_max_idx = {}\n",
    "    for f in base_features:\n",
    "        lbe = LabelEncoder()\n",
    "        data[f] = lbe.fit_transform(data[f]) + 1\n",
    "        feature_max_idx[f] = data[f].max() + 1\n",
    "        \n",
    "    # 构建用户id词典和doc的id词典，方便从用户idx找到原始的id\n",
    "    user_id_enc = data[['user_id']].drop_duplicates('user_id')\n",
    "    doc_id_enc = data[['article_id']].drop_duplicates('article_id')\n",
    "    user_idx_2_rawid = dict(zip(user_id_enc['user_id'], user_id_raw['user_id']))\n",
    "    doc_idx_2_rawid = dict(zip(doc_id_enc['article_id'], doc_id_raw['article_id']))\n",
    "    \n",
    "    user_profile = data[['user_id', 'gender', 'age', 'city']].drop_duplicates('user_id')\n",
    "    item_profile = data[['article_id']].drop_duplicates('article_id')\n",
    "    user_profile.set_index('user_id', inplace=True)\n",
    "    \n",
    "    train_set, test_set = get_data_set(user_click_hist_df, seq_short_len=SEQ_LEN_short, seq_prefer_len=SEQ_LEN_prefer)\n",
    "    \n",
    "    train_model_input, train_label = gen_model_input(train_set, user_profile, SEQ_LEN_short, SEQ_LEN_prefer)\n",
    "    test_model_input, test_label = gen_model_input(test_set, user_profile, SEQ_LEN_short, SEQ_LEN_prefer)\n",
    "       \n",
    "    # 构建模型并完成训练\n",
    "    model = train_sdm_model(train_model_input, train_label, embedding_dim, feature_max_idx, SEQ_LEN_short, SEQ_LEN_prefer, batch_size, epochs, verbose, validation_split)\n",
    "    \n",
    "    # 获得用户embedding和doc的embedding， 并进行保存\n",
    "    user_embs, doc_embs = get_embeddings(model, test_model_input, user_idx_2_rawid, doc_idx_2_rawid)\n",
    "    \n",
    "    # 对每个用户，拿到召回结果并返回回来\n",
    "    user_recall_items_dict = get_sdm_recall_res(user_embs, doc_embs, user_idx_2_rawid, doc_idx_2_rawid, topk)\n",
    "    \n",
    "    return user_recall_items_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3d1e905c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████| 20000/20000 [00:08<00:00, 2338.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\keras\\initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\keras\\initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From E:\\Jupyter Notebook\\推荐系统\\fun-rec-tmp\\SDM模型\\SDM.py:156: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From E:\\Jupyter Notebook\\推荐系统\\fun-rec-tmp\\SDM模型\\SDM.py:156: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From E:\\Jupyter Notebook\\推荐系统\\fun-rec-tmp\\SDM模型\\SDM.py:184: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From E:\\Jupyter Notebook\\推荐系统\\fun-rec-tmp\\SDM模型\\SDM.py:184: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\autograph\\impl\\api.py:330: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\autograph\\impl\\api.py:330: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\rnn_cell_impl.py:735: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\rnn_cell_impl.py:735: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\rnn_cell_impl.py:739: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\ops\\rnn_cell_impl.py:739: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\keras\\initializers.py:94: calling TruncatedNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\keras\\initializers.py:94: calling TruncatedNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 526056 samples\n",
      "526056/526056 [==============================] - 1315s 3ms/sample - loss: 0.3605\n"
     ]
    }
   ],
   "source": [
    "user_recall_doc_dict = sdm_recall(user_click_hist_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d701e611",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19999"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(user_recall_doc_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aacbbe9",
   "metadata": {},
   "source": [
    "## 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "01e81b96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 依次评估召回的前10, 20, 30, 40, 50个文章中的击中率\n",
    "def metrics_recall(user_recall_items_dict, trn_last_click_df, topk=100):\n",
    "    last_click_item_dict = dict(zip(trn_last_click_df['user_id'], trn_last_click_df['article_id']))\n",
    "    user_num = len(user_recall_items_dict)\n",
    "    \n",
    "    for k in range(50, topk+1, 50):\n",
    "        hit_num = 0\n",
    "        for user, item_list in user_recall_items_dict.items():\n",
    "            if user in last_click_item_dict:\n",
    "                # 获取前k个召回的结果\n",
    "                tmp_recall_items = [x[0] for x in user_recall_items_dict[user][:k]]\n",
    "                if last_click_item_dict[user] in set(tmp_recall_items):\n",
    "                    hit_num += 1\n",
    "        \n",
    "        hit_rate = round(hit_num * 1.0 / user_num, 5)\n",
    "        print(' topk: ', k, ' : ', 'hit_num: ', hit_num, 'hit_rate: ', hit_rate, 'user_num : ', user_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a21b0e7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " topk:  50  :  hit_num:  28 hit_rate:  0.0014 user_num :  19999\n",
      " topk:  100  :  hit_num:  51 hit_rate:  0.00255 user_num :  19999\n",
      " topk:  150  :  hit_num:  70 hit_rate:  0.0035 user_num :  19999\n",
      " topk:  200  :  hit_num:  90 hit_rate:  0.0045 user_num :  19999\n"
     ]
    }
   ],
   "source": [
    "metrics_recall(user_recall_doc_dict, user_click_last_df, topk=200)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
